{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "## 使用huggingface对thuglm-6b进行训练"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "\n",
    "# trust_remote_code=True lets transformers execute the Python code shipped with the\n",
    "# \"thuglm\" checkpoint, so only point this at a checkpoint you trust.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"thuglm\", trust_remote_code=True)\n",
    "\n",
    "# Load the weights, then cast them to fp16 and move the model to the current CUDA device.\n",
    "# (The author noted device_map='auto' as an alternative; it is not used here.)\n",
    "model = AutoModel.from_pretrained(\"thuglm\", trust_remote_code=True)\n",
    "model = model.half().cuda()\n",
    "\n",
    "# Optional chat smoke test, left disabled:\n",
    "# response, history = model.chat(tokenizer, \"你好\", history=[])\n",
    "# print(response)\n",
    "# response, history = model.chat(tokenizer, \"晚上睡不着应该怎么办\", history=history)\n",
    "# print(response)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, DatasetDict\n",
    "from glob import glob\n",
    "import random\n",
    "\n",
    "# Knobs for the file-level train/validation split.\n",
    "DATA_GLOB = \"D:/数据/train_gpt2/gpt2_data/*/**\"\n",
    "MAX_FILES = 10\n",
    "NUM_VALID_FILES = 3\n",
    "SEED = 42\n",
    "\n",
    "random.seed(SEED)\n",
    "\n",
    "# glob() gives no ordering guarantee, so sort before slicing and sampling; otherwise\n",
    "# the seed alone does not pin down which files land in each split.\n",
    "# Note: recursive=True is not passed, so \"**\" matches like a single \"*\".\n",
    "all_file_list = sorted(glob(pathname=DATA_GLOB))[:MAX_FILES]\n",
    "\n",
    "# Fail early with a readable message instead of an opaque random.sample() error\n",
    "# (too few files) or an empty training split (exactly NUM_VALID_FILES files).\n",
    "if len(all_file_list) <= NUM_VALID_FILES:\n",
    "    raise ValueError(\n",
    "        f\"Need more than {NUM_VALID_FILES} files matching {DATA_GLOB!r}, \"\n",
    "        f\"found {len(all_file_list)}.\"\n",
    "    )\n",
    "\n",
    "# Hold out NUM_VALID_FILES files for validation; everything else is used for training.\n",
    "test_file_list = random.sample(all_file_list, NUM_VALID_FILES)\n",
    "held_out = set(test_file_list)\n",
    "train_file_list = [path for path in all_file_list if path not in held_out]\n",
    "\n",
    "len(train_file_list), len(test_file_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the CSV files into two named splits, \"train\" and \"valid\";\n",
    "# the datasets cache is kept under ./cache_data.\n",
    "split_files = {\"train\": train_file_list, \"valid\": test_file_list}\n",
    "raw_datasets = load_dataset(\n",
    "    \"csv\",\n",
    "    data_files=split_files,\n",
    "    cache_dir=\"cache_data\",\n",
    ")\n",
    "\n",
    "raw_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "context_length = 64\n",
    "\n",
    "\n",
    "def tokenize(element):\n",
    "    \"\"\"Tokenize the ``content`` field and keep only full-length chunks.\n",
    "\n",
    "    Args:\n",
    "        element: A batch from ``Dataset.map(batched=True)``; its ``\"content\"``\n",
    "            entry is passed straight to the tokenizer.\n",
    "\n",
    "    Returns:\n",
    "        A dict with one key, ``\"input_ids\"``, holding every token-id sequence\n",
    "        whose reported length equals ``context_length``. Shorter sequences\n",
    "        are discarded.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(\n",
    "        element[\"content\"],\n",
    "        truncation=True,\n",
    "        max_length=context_length,\n",
    "        return_overflowing_tokens=True,\n",
    "        return_length=True,\n",
    "    )\n",
    "    full_chunks = [\n",
    "        ids\n",
    "        for n_tokens, ids in zip(encoded[\"length\"], encoded[\"input_ids\"])\n",
    "        if n_tokens == context_length\n",
    "    ]\n",
    "    return {\"input_ids\": full_chunks}\n",
    "\n",
    "\n",
    "# tokenize() can return a different number of rows than it receives, so the\n",
    "# original columns are dropped and only \"input_ids\" remains.\n",
    "columns_to_drop = raw_datasets[\"train\"].column_names\n",
    "tokenized_datasets = raw_datasets.map(\n",
    "    tokenize,\n",
    "    batched=True,\n",
    "    remove_columns=columns_to_drop,\n",
    ")\n",
    "tokenized_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForLanguageModeling\n",
    "\n",
    "# mlm=False turns off masked-language-model masking in the collator.\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "# NOTE(review): the model was cast with .half() when it was loaded; combining\n",
    "# already-fp16 weights with fp16=True may raise \"Attempting to unscale FP16\n",
    "# gradients\" -- confirm on the target transformers/torch versions.\n",
    "# NOTE(review): warmup_steps=1000 may exceed the total number of optimizer steps\n",
    "# for a dataset built from only a handful of files -- confirm.\n",
    "args = TrainingArguments(\n",
    "    output_dir=\"test001\",\n",
    "    push_to_hub=False,\n",
    "    # Batching: one sample per device, gradients accumulated over 8 steps.\n",
    "    per_device_train_batch_size=1,\n",
    "    per_device_eval_batch_size=1,\n",
    "    gradient_accumulation_steps=8,\n",
    "    # Optimisation schedule.\n",
    "    num_train_epochs=2,\n",
    "    learning_rate=5e-4,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    warmup_steps=1000,\n",
    "    weight_decay=0.1,\n",
    "    fp16=True,\n",
    "    # Evaluation, logging and checkpointing all use a 20-step interval.\n",
    "    evaluation_strategy=\"steps\",\n",
    "    eval_steps=20,\n",
    "    logging_steps=20,\n",
    "    save_steps=20,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"valid\"],\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mynet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
